# Setup ----------------------------------------------------------------------
# Use the project root (as located by {here}) as the working directory so every
# relative path in this script resolves against it.
# NOTE(review): setwd() changes global state; building paths with
# here("data", ...) directly would avoid that.
library(here)
setwd(here())

# Intermediate inputs from earlier pipeline steps. prom.sum, mii.lookup,
# promise.index and pf are used below and are presumably restored by these
# files -- confirm. Files are loaded in order into the global environment.
invisible(lapply(
  file.path("data", "intermediate",
            c("prom_sum.rda", "mii.lookup.rda", "promise.index.rda", "pf.rda")),
  load,
  envir = globalenv()
))

library(haven)
library(reshape2)
library(mellonMisc)
library(dplyr)
library(glmnet)
library(ridge)
library(brms)
library(readxl)
library(metRology)
library(survey)
library(plm)
library(plotly)
library(ggplot2)
library(readr)
# First-stage posterior draws -------------------------------------------------
# Load the fitted first-stage models (fit.all: all respondents; fit.con: with
# `votecon` slopes) and build, per promise, posterior draws of its total
# importance effect (promise effect + effect of its MII category).
# NOTE(review): if this .Rdata was written by save.image(), loading it may also
# overwrite objects loaded above (e.g. prom.sum, promise.index) -- confirm.
load(file = "data/intermediate/models_multivariate.Rdata")
set.seed(234982348)

# brms column name(s) of a group-level effect on the `import` response, e.g.
# import_re_cols("mmg1g2g3g4", 12, "Intercept")
#   -> "r_mmg1g2g3g4__import[12,Intercept]"
import_re_cols <- function(group, levels, coef) {
  paste0("r_", group, "__import", "[", levels, ",", coef, "]")
}

# All respondents: promise-level and MII-level random intercepts.
samples <- posterior_samples(fit.all, "mmg1g2g3g4|mmmii1mii2mii3mii4")
imp.samples <- samples[, import_re_cols("mmg1g2g3g4",
                                        rownames(promise.index), "Intercept")]
mii.imp.samples <- samples[, import_re_cols("mmmii1mii2mii3mii4",
                                            mii.lookup$code, "Intercept")]

# Align two promise texts with the spelling used in pf$Snippets so that the
# text lookup below finds them.
promise.index$text[promise.index$text ==
  "Create a 'Veterans Board' to co-ordinate treatment of military veterans"] <-
  "Create a 'Veterans Board' to coordinate treatment of military veterans"
promise.index$text[promise.index$text == "Keep the Trident nuclear deterrant"] <-
  "Keep the Trident nuclear deterrent"

# Promise text -> MII category (via pf) -> MII category code (via mii.lookup).
promise.index$mii <- pf$agreed_cat_string[match(promise.index$text, pf$Snippets)]
promise.index$mii.code <- mii.lookup$code[match(promise.index$mii, mii.lookup$cat)]
# Fail loudly (naming the promises) instead of with an opaque column-selection
# error further down.
if (anyNA(promise.index$mii.code)) {
  stop("No MII category code found for promise(s): ",
       paste(promise.index$text[is.na(promise.index$mii.code)], collapse = "; "),
       call. = FALSE)
}

# Add each promise's MII-category effect. Columns are selected by their brms
# name rather than by using the code as a column position, so this stays
# correct even if MII codes are not exactly 1..K in lookup order.
imp.samples <- imp.samples +
  mii.imp.samples[, import_re_cols("mmmii1mii2mii3mii4",
                                   promise.index$mii.code, "Intercept")]

# Conservative-slope model: effects evaluated at votecon = 1 (intercept plus
# votecon slope), presumably 2017 Conservative voters -- confirm coding.
samples.c <- posterior_samples(fit.con, "mmg1g2g3g4|mmmii1mii2mii3mii4")

imp.samples.c <- samples.c[, import_re_cols("mmg1g2g3g4",
                                            rownames(promise.index), "Intercept")]
imp.samples.c.vc <- samples.c[, import_re_cols("mmg1g2g3g4",
                                               rownames(promise.index), "votecon")]
imp.samples.c <- imp.samples.c + imp.samples.c.vc

mii.imp.samples.c <- samples.c[, import_re_cols("mmmii1mii2mii3mii4",
                                                mii.lookup$code, "Intercept")]
mii.imp.samples.c.vc <- samples.c[, import_re_cols("mmmii1mii2mii3mii4",
                                                   mii.lookup$code, "votecon")]
# Column names of the sum come from the left operand (the Intercept names).
mii.imp.samples.c <- mii.imp.samples.c + mii.imp.samples.c.vc
imp.samples.c <- imp.samples.c +
  mii.imp.samples.c[, import_re_cols("mmmii1mii2mii3mii4",
                                     promise.index$mii.code, "Intercept")]

# Mixture classification of promises -----------------------------------------
library(mixtools)

# Posterior-mean importance of each promise (all respondents), saved for reuse.
imp.simple.means <- colMeans(imp.samples)
save(imp.simple.means, file = "data/intermediate/imp.simple.means.rda")

# Normal mixtures (mixtools default k = 2) on posterior-mean importance, for
# all respondents and for the Conservative-slope draws. EM starting values are
# random, so the fits depend on the RNG state set earlier in the script.
mix.mod.all <- normalmixEM(imp.simple.means, maxit = 10000)
mix.mod.con <- normalmixEM(colMeans(imp.samples.c), maxit = 10000)

# Column of `mix$posterior` treated as the "major promise" class: the component
# with the smallest total posterior membership.
# NOTE(review): this assumes the smaller component is the high-importance one;
# which.max(mix$mu) would select that directly -- confirm intent.
major_component <- function(mix) {
  which.min(colSums(mix$posterior))
}

# One row per promise, in the column order of imp.samples (the rownames of
# promise.index). `text` comes from promise.index, which is aligned with those
# columns by construction; pf$Snippets is not guaranteed to share that order
# (the script itself re-matches pf by text).
# NOTE(review): `promise` assumes promise ids are 1..n, matching the ids in
# prom.sum$g1..g4 -- confirm.
group_assign <- dtf(
  promise = seq_len(ncol(imp.samples)),
  text = promise.index$text,
  majorPromiseProb = mix.mod.all$posterior[, major_component(mix.mod.all)],
  majorPromiseProb.c = mix.mod.con$posterior[, major_component(mix.mod.con)]
)

# Mean "major promise" probability over the four promises (g1..g4) on each row
# of prom.sum. NA if any of the four has no match in group_assign$promise.
mean_major_prob <- function(prob, slots = c("g1", "g2", "g3", "g4")) {
  per_slot <- lapply(slots, function(g) {
    prob[match(prom.sum[[g]], group_assign$promise)]
  })
  Reduce(`+`, per_slot) / length(per_slot)
}

prom.sum$classprop <- mean_major_prob(group_assign$majorPromiseProb)
prom.sum$classprop.c <- mean_major_prob(group_assign$majorPromiseProb.c)


# Second-stage model (all respondents) ---------------------------------------
# Project helper functions (contents not visible in this script).
source("scripts/promise_functions.R")
set.seed(13824)
# Weighted importance model: promise-class membership `classprop` (computed
# above from the mixture fit) as a fixed effect, plus multi-membership random
# intercepts over the four promises (g1-g4) and their four MII categories
# (mii1-mii4), and random intercepts for respondent (id) and page. The |p| and
# |q| IDs make brms model these effects as correlated with the same-ID effects
# in the approval formula below.
fit <-  bf(import|weights(weight) ~  classprop + 
             (1 |p|mm(g1, g2, g3, g4)) +
             (1 |q|mm(mii1, mii2, mii3, mii4)) +
             (1| id) + (1|page))

# Weighted approval model with the same group-level structure and no
# population-level predictors besides the intercept.
fit.approv <-  bf(approval|weights(weight) ~   
                    (1 |p|mm(g1, g2, g3, g4)) +
                    (1 |q|mm(mii1, mii2, mii3, mii4)) +
                    (1|id) + (1|page))


# Show the parameter classes available for priors in the joint model.
get_prior(formula = fit + fit.approv, data = prom.sum)
# Joint (multivariate) fit of both responses. There is no "b" prior for
# approval because its formula has no population-level slopes.
fit.all.fmm <- brm(fit + fit.approv, data = prom.sum, 
                   cores = 4, prior = c(set_prior('exponential(0.7)', class = "sd", resp = "approval"),
                                        set_prior('exponential(0.7)', class = "sd", resp = "import"),
                                        set_prior('normal(0, 5)', class = "b", resp = "import"),
                                        set_prior('lkj(2)', class = "cor")) , 
                   seed = 2384, iter = 4000, warmup = 2000)

# Checkpoint: save the whole workspace, then reload it. The reload is
# presumably a restart point for running the rest of the script separately.
save.image(file = "data/intermediate/models_multivariatefmm.Rdata")
load(file = "data/intermediate/models_multivariatefmm.Rdata")


library(haven)
library(reshape2)
library(mellonMisc)
library(dplyr)
library(glmnet)
library(ridge)
library(brms)
library(readxl)
library(metRology)
library(survey)
library(plm)
library(plotly)
library(ggplot2)
library(readr)
# Second-stage model with Conservative-vote slopes ---------------------------
set.seed(2394239)
# Weighted importance model: votecon, both class-membership measures
# (classprop from the all-respondent mixture, classprop.c from the
# Conservative-slope mixture) and their interactions with votecon. Random
# intercepts and votecon slopes are multi-membership over promises (g1-g4) and
# MII categories (mii1-mii4), correlated with the approval model through the
# shared |p| / |q| IDs. Respondent (id) and page intercepts are also included.
fit.c <-  bf(import|weights(weight) ~  votecon + 
               classprop + classprop.c +
               votecon:classprop + votecon:classprop.c +
               (1 +votecon|p|mm(g1, g2, g3, g4)) +
               (1 +votecon|q|mm(mii1, mii2, mii3, mii4)) +
               (1| id) + (1|page))

# Weighted approval model with a votecon fixed effect and the same
# group-level structure.
fit.approv.c <-  bf(approval|weights(weight) ~ votecon + 
                      (1 +votecon|p|mm(g1, g2, g3, g4)) +
                      (1 +votecon|q|mm(mii1, mii2, mii3, mii4)) +
                      (1|id) + (1|page))

# Show the parameter classes available for priors in the joint model.
get_prior(formula = fit.c + fit.approv.c, data = prom.sum)
# Joint fit of both responses. Both now have population-level slopes, so both
# get a "b" prior. adapt_delta is raised above the brms default to reduce
# divergent transitions.
fit.con.fmm <- brm(fit.c + fit.approv.c, data = prom.sum, 
               cores = 4, prior = c(set_prior('exponential(0.7)', class = "sd", resp = "approval"),
                                    set_prior('exponential(0.7)', class = "sd", resp = "import"),
                                    set_prior('normal(0, 5)', class = "b", resp = "approval"),
                                    set_prior('normal(0, 5)', class = "b", resp = "import"),
                                    set_prior('lkj(2)', class = "cor") ) ,
               seed = 43594, iter = 4000, warmup = 2000, control=list(adapt_delta=0.95))

# Overwrite the checkpoint so it also contains fit.con.fmm.
save.image(file = "data/intermediate/models_multivariatefmm.Rdata")

#### Output ####

# Appendix D: MCMC diagnostic pages for the first-stage models (fit.all,
# fit.con) and the second-stage models (fit.all.fmm, fit.con.fmm).
# plot() returns a list of pages, indexed below. N is the number of
# parameters per page (brms default for fit.all).
fit.diags <- plot(fit.all)
fit.diags.c <- plot(fit.con, N = 10)
fmm.diag.c <- plot(fit.con.fmm, N = 10)
fmm.diag <- plot(fit.all.fmm, N = 10)

# Figure 4: diagnostics for the first-stage model (all respondents);
# pages 1-3 at saveForPub's default height.
for (page in 1:3) {
  saveForPub(fit.diags[[page]], file = paste0("figures/diag_stage1_", page))
}

# Figure 5: diagnostics for the first-stage model (2017 Conservative slopes);
# the fourth page is short.
stage1.c.heights <- c(12, 12, 12, 2)
for (page in seq_along(stage1.c.heights)) {
  saveForPub(fit.diags.c[[page]],
             file = paste0("figures/diag_c_stage1_", page),
             height = stage1.c.heights[page])
}

# Figure 6: diagnostics for the second-stage model (all respondents).
for (page in 1:2) {
  saveForPub(fmm.diag[[page]],
             file = paste0("figures/diag_stage2_", page),
             height = 12)
}

# Figure 7: diagnostics for the second-stage model (Conservative slopes).
for (page in 1:4) {
  saveForPub(fmm.diag.c[[page]],
             file = paste0("figures/diag_c_stage2_", page),
             height = 12)
}